[YOLOv8] 生産ラインを流れるアヒルを追跡して数をかぞえてみました

[YOLOv8] 生産ラインを流れるアヒルを追跡して数をかぞえてみました

Clock Icon2024.07.15

1 はじめに

製造ビジネステクノロジー部の平内(SIN)です。

Ultralytics 社の YOLOv8は、最先端、高速、正確で非常に使いやすく設計された物体検出モデルです。

001

YOLOv8は、さまざまなオブジェクトの検出、インスタンスのセグメンテーション、画像分類、ポーズ推定などを処理することが可能ですが、トラックング(追跡)タスクについても対応しています。

002

今回は、このオブジェクト検出及び、トラッキングを使用して、生産ラインを流れるアヒルをカウントしてみました。

最初に、作成したデモをご確認下さい。

https://www.youtube.com/watch?v=-cKSeygprDQ

2 物体検出(ファインチューニング)

最初に、アヒルを検出するためモデル作成します。

YOLOv8のファインチューニングは、非常に簡単で、形式通りのデータを準備してmodel.train()を実行するだけです。

from ultralytics import YOLO

model = YOLO("yolov8l.pt")
model.train(data="dataset.yaml", epochs=15, batch=8, workers=4, degrees=90.0)

データセットの配置は、dataset.yamlで指定します。

dataset.yaml

train: /home/dataset/yolo/train/images
val: /home/dataset/yolo/valid/images

nc: 1
names: ["ahiru"]

実際のデータの配置は、以下のとおりです。

.
├── train
│   ├── images
│   │   ├── 00600.png
│   │   ├── 00601.png
・・・略・・・
│   │   ├── 02998.png
│   │   └── 02999.png
│   ├── labels
│   │   ├── 00600.txt
│   │   ├── 00601.txt
・・・略・・・
│   │   ├── 02998.txt
│   │   └── 02999.txt
│   └── labels.cache
└── valid
    ├── images
    │   ├── 00000.png
    │   ├── 00001.png
・・・略・・・
    │   ├── 00598.png
    │   └── 00599.png
    ├── labels
    │   ├── 00000.txt
    │   ├── 00001.txt
・・・略・・・
    │   ├── 00598.txt
    │   └── 00599.txt
    └── labels.cache

6 directories, 6002 files

データセットは、対象オブジェクトを撮影し、Segment Anything Modelで切り出して、プログラムによる合成により大量生産しています。(今回は、3,000枚の画像と約24,000個のアノテーション)

カメラによる撮影

005

006

データセット生成

007

参考:

https://dev.classmethod.jp/articles/sygment-anything-create-dataset-image/

https://dev.classmethod.jp/articles/yolov5-nvidia-jetson-agx-orin/

トレーニングの状況です。

__
004

003

epoch,    train/box_loss,     train/cls_loss, metrics/mAP50(B),    metrics/mAP50-95(B),r/pg2
    1,            1.0073,            0.62287,          0.92468,                0.42867,66444
    2,           0.88396,            0.49335,          0.87443,                0.57074,12433
    3,           0.82376,            0.46769,          0.99497,                0.70049,17341
    4,           0.75457,            0.42457,          0.99496,                0.74066,01604
    5,           0.68971,            0.39211,          0.98101,                0.71505,01472
    6,           0.56326,            0.28422,          0.93799,                0.72821,00134
    7,           0.53406,            0.26958,            0.995,                0.80915,01208
    8,           0.49782,            0.25346,          0.99287,                0.89071,01076
    9,           0.45277,            0.23768,            0.995,                0.86261,00944
   10,           0.41017,            0.21778,            0.995,                0.91045,00812
   11,            0.3922,            0.21385,            0.995,                0.94417,00068
   12,           0.35645,            0.19367,            0.995,                0.91055,00548
   13,           0.33934,            0.18195,            0.995,                0.92263,00416
   14,           0.31723,            0.17222,            0.995,                0.96464,00284
   15,            0.2961,            0.16354,            0.995,                0.97386,00152

3 トラッキング

作成したモデルでトラッキングしているコードです。

検出したオブジェクトのIDを確認し、画面の右に位置する時にリストアップしておき、中央を超えた時点で、カウンターをアップしています。

import cv2
from ultralytics import YOLO

COLORS = [
    (255, 80, 0),
    (255, 255, 0),
    (255, 80, 100),
    (255, 80, 255),
    (255, 120, 255),
    (155, 255, 255),
    (155, 155, 255),
    (155, 200, 200),
    (155, 80, 155),
    (200, 200, 200),
]

model = YOLO("./runs/detect/train4/weights/best.pt")

class Counter:

    before_counting_id_list = []  # 画面の右側にあるカウントする前のIDリスト

    def __init__(self, w, h):
        self.counter = 0
        self.w = w
        self.h = h

    def set(self, id, box):
        # 対象オブジェクトのX座標の中心を取得
        x1 = box[0]
        x2 = box[2]
        center = int(x1 + (x2 - x1) / 2)

        # 画面の中央より右側にある場合
        if center > int(self.w / 2):
            # まだ、リストの損じしない場合、リストに追加
            if not id in self.before_counting_id_list:
                self.before_counting_id_list.append(id)

        # 画面の中央より左側にある場合
        if center < int(self.w / 2):
            # IDにあれば、リストを削除してカウントアップする
            if id in self.before_counting_id_list:
                self.counter += 1
                self.before_counting_id_list.remove(id)

    def disp_counter(self, frame):
        # 中央のラインを描画
        cv2.line(
            frame, (int(self.w / 2), 70), (int(self.w / 2), self.h - 20), (0, 0, 255), 2
        )
        # カウンターを描画
        cv2.putText(
            frame,
            "COUNTER: {}".format(self.counter),
            (230, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 0, 255),
            3,
        )

# 検出したオブジェクトにラベルを表示する関数
def disp_label(frame, box, id):
    color = COLORS[id % 10]
    x1 = int(box[0])
    y1 = int(box[1])
    x2 = int(box[2])
    y2 = int(box[3])

    cv2.rectangle(
        frame,
        (x1, y1),
        (x2, y2),
        color,
        2,
    )

    cv2.putText(
        frame,
        "ID:{}".format(id),
        (x1, y1 - 15),
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        color,
        3,
    )

def main() -> int:

    cap = cv2.VideoCapture(0)
    cap.set(cv2.CAP_PROP_FPS, 5)  # 強制的にFLSを下げる

    w, h, fps = (
        int(cap.get(x))
        for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)
    )
    print("w:{} h:{} fps:{}".format(w, h, fps))

    counter = Counter(w, h)

    while True:
        ret, frame = cap.read()
        if not ret:
            print("Video ERROR")
            break

        # 検出したオブジェクトを取得
        results = model.track(frame, persist=True, conf=0.3)

        # 検出した個々のオブジェクトを処理する
        for box in results[0].boxes:
            r = box.xyxy.tolist()
            # トラッキングIDを取得
            id = int(box.id) if box.id is not None else 0
            # カウント処理
            counter.set(id, r[0])
            disp_label(frame, r[0], id)

        # カウント表示
        counter.disp_counter(frame)

        frame = cv2.resize(frame, dsize=None, fx=1.5, fy=1.5)
        cv2.imshow("frame", frame)

        if cv2.waitKey(1) & 0xFF == ord("q"):
            break

    cap.release()
    cv2.destroyAllWindows()

if __name__ == "__main__":
    main()

4 最後に

YOLOv8で物体検出すると、検出されたオブジェクトの固有IDが取得できるため、これを使用することで、各種のトラッキングタスクが処理可能になります。

ただし、追跡のためには、対象オブジェクトを駒落ちせずに確実に検出する必要があるため、物体検出モデルの精度は、比較的高いものが要求されます。

SAMを使用したデータセット作成では、非常に質の高い大量のデータが簡単に生成できるため、これを可能にしているとも言えそうです。

5 参考リンク

https://dev.classmethod.jp/articles/yolov8-trial-custom-dataset/

https://github.com/RizwanMunawar/yolov8-object-tracking?tab=readme-ov-file

https://muhammadrizwanmunawar.medium.com/ultralytics-yolov8-object-trackers-botsort-vs-bytetrack-comparison-d32d5c82ebf3

https://contra.com/p/A7XcWY6Z-yolo-v10-object-detection-and-tracking

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.